# coding=utf-8
'''
Created on 2020-6-1
@author: jiangao
Project: Brovey图像融合方法
'''
import numpy as np
import cv2
import scipy.misc as smi
from osgeo import gdal
from PIL import Image


def gdal_open(path):
    """
    读取图像函数
    输入：图像路径
    返回：np.ndArray格式的三维数组
    """
    data = gdal.Open(path)
    col = data.RasterXSize  # 读取图像长度
    print(col)
    row = data.RasterYSize  # 读取图像宽度
    print(row)
    band1 = data.GetRasterBand(3)
    red = band1.ReadAsArray()# 读取图像第一波段并转换为数组

    data_array_g = data.GetRasterBand(2).ReadAsArray()  # 读取图像第二波段并转换为数组
    data_array_b = data.GetRasterBand(1).ReadAsArray() # 读取图像第三波段并转换为数组
    data_array = np.array((red, data_array_g, data_array_b))


    return data_array

def gdal_openpAN(path):
    """
    读取图像函数
    输入：图像路径
    返回：np.ndArray格式的三维数组
    """
    data = gdal.Open(path)
    col = data.RasterXSize  # 读取图像长度
    print(col)
    row = data.RasterYSize  # 读取图像宽度
    print(row)
    band1 = data.GetRasterBand(1)
    red = band1.ReadAsArray()# 读取图像第一波段并转换为数组
    data_array1 = np.array(red)
    return data_array1
def imresize(data_low, data_high):
    """
    图像缩放函数
    输入：np.ndArray格式的三维数组
    返回：np.ndArray格式的三维数组
    """
    band = 1
    col, row = data_high.shape
    data = np.zeros(((band, col, row)))
    for i in range(0, band):
        data[i] = smi.imresize(data_low[i], (col, row))


    return data


def brovey(data_low, data_high):
    """
    色彩标准化融合函数
    输入：np.ndArray格式的三维数组
    返回：可绘出图像的utf-8格式的三维数组
    """
    band, col, row = data_low.shape
    total = 0
    for i in range(0, band):
        total = total + data_low[i]
    RGB = np.zeros(((band, col, row)))
    for i in range(0, band):
        RGB[i] = data_low[i] * data_high[i] / total
    min_val = np.min(RGB.ravel())
    max_val = np.max(RGB.ravel())
    RGB = np.uint8((RGB.astype(np.float) - min_val) / (max_val - min_val) * 255)
    RGB = Image.fromarray(cv2.merge([RGB[0], RGB[1], RGB[2]]))


    return RGB


def main(path_low, path_high):
    data_low = gdal_open(path_low)
    data_high = gdal_openpAN(path_high)
    data_low = imresize(data_low, data_high)
    RGB = brovey(data_low, data_high)
    RGB.save(r'G:\01工作空间\Brovey.png', 'png')


if __name__ == "__main__":
    path_low = r'G:\01工作空间\test1\TEST\GF2_PMS1_E110.7_N21.4_20210223_L1A0005501536-MSS1.tiff'
    path_high = r'G:\01工作空间\test1\TEST\GF2_PMS1_E110.7_N21.4_20210223_L1A0005501536-PAN1.tiff'
    main(path_low, path_high)